#This file produces images for the paper
# Remove every visible object from the current workspace so the script does
# not pick up variables left over from an earlier session.
rm(list=ls())
# The CSV inputs read below and the PDF figures written below all use file
# names relative to this directory.
setwd("...") # Fill in working directory 
library(dplyr)
library(ggplot2)   
library(Hmisc)      
library(plyr)
library(pROC)
library(gridExtra)
library(rlang)
library(cowplot)

#Data Prep----
# Main analysis: one results table per cohort. The "X" column is dropped
# (presumably the row-name column written by write.csv -- confirm).
older <- select(
  read.csv("older_cohort_results.csv", stringsAsFactors = TRUE),
  -"X"
)
younger <- select(
  read.csv("younger_cohort_results.csv", stringsAsFactors = TRUE),
  -"X"
)

# Make "never_arrested" and "White and Other" the reference (first) factor
# levels in both cohorts.
older[["arrested17_24"]] <- relevel(older[["arrested17_24"]], ref = "never_arrested")
older[["sp_ethn_aw"]] <- relevel(older[["sp_ethn_aw"]], ref = "White and Other")

younger[["arrested17_24"]] <- relevel(younger[["arrested17_24"]], ref = "never_arrested")
younger[["sp_ethn_aw"]] <- relevel(younger[["sp_ethn_aw"]], ref = "White and Other")

# Logical outcome vectors: TRUE where the person was arrested at ages 17-24.
older_outcomes <- older[["arrested17_24"]] == "arrested"
younger_outcomes <- younger[["arrested17_24"]] == "arrested"


# "binCH" results: same layout, but with separate arrest indicators for
# ages 17-18 and ages 19-24.
older_CH <- select(
  read.csv("older_cohort_binCH_results.csv", stringsAsFactors = TRUE),
  -"X"
)
younger_CH <- select(
  read.csv("younger_cohort_binCH_results.csv", stringsAsFactors = TRUE),
  -"X"
)

older_CH[["arrested19_24"]] <- relevel(older_CH[["arrested19_24"]], ref = "never_arrested")
older_CH[["arrested17_18"]] <- relevel(older_CH[["arrested17_18"]], ref = "never_arrested")
older_CH[["sp_ethn_aw"]] <- relevel(older_CH[["sp_ethn_aw"]], ref = "White and Other")

younger_CH[["arrested19_24"]] <- relevel(younger_CH[["arrested19_24"]], ref = "never_arrested")
younger_CH[["arrested17_18"]] <- relevel(younger_CH[["arrested17_18"]], ref = "never_arrested")
younger_CH[["sp_ethn_aw"]] <- relevel(younger_CH[["sp_ethn_aw"]], ref = "White and Other")

# Logical outcome vectors: TRUE where the person was arrested at ages 19-24.
older_CH_outcomes <- older_CH[["arrested19_24"]] == "arrested"
younger_CH_outcomes <- younger_CH[["arrested19_24"]] == "arrested"

#Function Prep----
# Compute ROC curve points for a training set and, optionally, a test set.
#
# Arguments:
#   train_prob   Predicted probabilities for the training observations.
#   train_labels Outcomes for the training observations (logical or 0/1),
#                same length as train_prob.
#   test_prob    Optional predicted probabilities for the test observations.
#   test_labels  Optional outcomes for the test observations; must be supplied
#                if and only if test_prob is supplied, with the same length.
#   spec         Step between successive classification thresholds; the
#                thresholds are seq(0, 1, spec).
#
# Returns a data frame with one row per threshold, ordered from the highest
# threshold to the lowest, with columns train_tpr and train_fpr, plus test_tpr
# and test_fpr when test data are supplied. An observation is predicted
# positive when its probability is >= the threshold.
#
# Stops with an error when the lengths of predictions and labels disagree or
# when only one of test_prob / test_labels is given.
get_train_test_ROC <- function(train_prob, train_labels, test_prob = NULL,
                               test_labels = NULL, spec = 0.005) {
  if (length(train_prob) != length(train_labels)) {
    stop("Training predictions and labels must be of the same length")
  }
  has_test_prob <- length(test_prob) > 0
  has_test_labels <- length(test_labels) > 0
  if (has_test_prob != has_test_labels) {
    stop("Either both predictions and labels for test data must be supplied or neither must be supplied")
  }
  if (length(test_prob) != length(test_labels)) {
    stop("Test predictions and labels must be of the same length")
  }

  thresholds <- seq(0, 1, spec)

  # Returns a 2 x length(thresholds) matrix: row 1 = TPR, row 2 = FPR.
  # The class counts do not depend on the threshold, so compute them once.
  rates_for <- function(prob, labels) {
    n_pos <- sum(labels)
    n_neg <- length(labels) - n_pos
    vapply(
      thresholds,
      function(threshold) {
        pred <- prob >= threshold
        c(
          sum(pred + labels == 2) / n_pos,
          sum(pred == 1 & labels == 0) / n_neg
        )
      },
      numeric(2)
    )
  }

  # Build all rows at once instead of growing the data frame row by row.
  train_rates <- rates_for(train_prob, train_labels)
  roc_points <- data.frame(
    train_tpr = train_rates[1, ],
    train_fpr = train_rates[2, ]
  )
  if (has_test_prob) {
    test_rates <- rates_for(test_prob, test_labels)
    roc_points$test_tpr <- test_rates[1, ]
    roc_points$test_fpr <- test_rates[2, ]
  }

  # Reverse so rows run from threshold 1 down to threshold 0 (row names keep
  # the original threshold index).
  roc_points[rev(rownames(roc_points)), ]
}

# Draw ROC curves for a training set and, optionally, a test set.
#
# Arguments:
#   train_prob, train_labels  Predictions and outcomes for the training set
#                             (labelled "Older Cohort" in the plot).
#   test_prob, test_labels    Optional predictions and outcomes for the test
#                             set (labelled "Younger Cohort" in the plot).
#   spec                      Threshold step passed to get_train_test_ROC().
#   legend_pos                Value for theme(legend.position = ).
#
# Returns a ggplot object with the ROC curve(s), the chance diagonal and the
# AUC value(s) printed in the lower-right corner.
plot_roc <- function(train_prob, train_labels, test_prob = NULL,
                     test_labels = NULL, spec = 0.005, legend_pos = "none") {
  roc_df <- get_train_test_ROC(train_prob, train_labels, test_prob, test_labels, spec)
  # The test curve and AUC are only built when test data were supplied, so the
  # NULL defaults are usable.
  has_test <- length(test_prob) > 0

  train_auc <- round(auc(response = train_labels, predictor = train_prob), 3)

  legend_breaks <- "red"
  legend_labels <- "Older Cohort ROC"

  roc_plot <- ggplot(data = roc_df) +
    geom_line(mapping = aes(x = train_fpr, y = train_tpr, color = "red"))

  if (has_test) {
    legend_breaks <- c(legend_breaks, "blue")
    legend_labels <- c(legend_labels, "Younger Cohort ROC")
    roc_plot <- roc_plot +
      geom_line(mapping = aes(x = test_fpr, y = test_tpr, color = "blue"),
                linetype = "dashed")
  }

  # annotate() draws the diagonal and each label exactly once; a geom_*()
  # layer with constant positions would be repeated for every row of roc_df.
  roc_plot <- roc_plot +
    labs(x = "False Positive Rate", y = "True Positive Rate", color = "") +
    scale_color_identity(labels = legend_labels,
                         breaks = legend_breaks,
                         guide = "legend") +
    annotate("segment", x = 0, y = 0, xend = 1, yend = 1) +
    annotate("text", x = 1, y = 0.15,
             label = paste("Older Cohort AUC:", train_auc),
             color = "red", hjust = 1, size = 2.63)

  if (has_test) {
    test_auc <- round(auc(response = test_labels, predictor = test_prob), 3)
    roc_plot <- roc_plot +
      annotate("text", x = 1, y = 0.07,
               label = paste("Younger Cohort AUC:", test_auc),
               color = "blue", hjust = 1, size = 2.63)
  }

  roc_plot +
    theme_bw() +
    theme(legend.position = legend_pos)
}

# Build the plotmath label for a through-the-origin fit of `y` on `x`.
#
# Arguments:
#   df    Data frame containing numeric columns named `x` and `y`.
#   y, x  Never supplied by the callers in this file; the names in the
#         formula are looked up as columns of `df`.
#
# Returns a length-one character string holding an unevaluated plotmath
# expression of the form  y == slope %.% x, r^2 = value  (slope formatted to
# 2 significant digits, r^2 to 3), suitable for geom_text(parse = TRUE).
# Note: the model has no intercept, so summary()$r.squared is the R^2 that
# summary.lm reports for a no-intercept fit.
lm_eq <- function(df, y, x) {
  fit <- lm(y ~ x + 0, data = df)

  slope_txt <- format(unname(coef(fit)[1]), digits = 2)
  r2_txt <- format(summary(fit)$r.squared, digits = 3)

  # Only `a` and `r2` are replaced; `y` and `x` stay as symbols in the label.
  label_expr <- substitute(
    y == a %.% x*","~~r^2~"="~r2,
    list(a = slope_txt, r2 = r2_txt)
  )
  as.character(as.expression(label_expr))
}

#Fig 1----
# ROC curves of the pnas_lr predictions: older cohort as the "train" curve,
# younger cohort as the "test" curve.
p1 <- plot_roc(
  train_prob = older$pnas_lr,
  train_labels = older_outcomes,
  test_prob = younger$pnas_lr,
  test_labels = younger_outcomes
)
#p2 = plot_roc(older$ks_lasso, older_outcomes, younger$ks_lasso, younger_outcomes) 

# Single-panel grid; labels = "AUTO" adds the panel letter.
fig1 <- plot_grid(
  p1, # p2,
  align = "h",
  axis = "b",
  labels = "AUTO"
)

save_plot(
  filename = "fig1.pdf",
  plot = fig1,
  device = "pdf",
  base_width = 2.25,
  base_height = 2.25,
  dpi = 1000
)


#Fig 2----
# Calibration plot: arrest indicator (ages 17-24) regressed through the origin
# on the pnas_lr predicted probability, one facet per cohort, with the fitted
# equation printed in each facet.
y1 <- younger %>%
  transmute(x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"),
            cohort = "Performance of Younger Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(y1)
# One row carrying the facet variables plus the equation label in `temp`.
y1text <- data.frame(y1[1, ], temp)
#y2 <- younger %>% transmute(x = ks_lasso, y= as.numeric(arrested17_24=="arrested"),
#                            cohort="Performance of Younger Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(y2)
#y2text <- data.frame(y2[1,], temp)

o1 <- older %>%
  transmute(x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"),
            cohort = "Performance of Older Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(o1)
o1text <- data.frame(o1[1, ], temp)
#o2 <- older %>% transmute(x = ks_lasso, y= as.numeric(arrested17_24=="arrested"),
#                          cohort="Performance of Older Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(o2)
#o2text <- data.frame(o2[1,], temp)

combo <- rbind(y1, o1 #,
               # y2, o2
               )
text_layer <- rbind(y1text,
                    # y2text,
                    o1text #,
                    #o2text
                    )
# Position of the equation label inside each facet.
text_layer$x <- .4
text_layer$y <- .9

# `label` is mapped only in the text layer (whose data has a `temp` column),
# so rendering the plot does not depend on the global variable `temp`.
# annotate() draws the y = x reference line once per facet instead of once
# per data row.
fig2 <- ggplot(data = combo, mapping = aes(x = x, y = y)) +
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1) +
  labs(x = "Predicted Probability of Arrest Between Ages 17 and 24",
       y = "Prop. Arrested Between Ages 17 and 24") +
  geom_smooth(method = "lm", se = FALSE, formula = y ~ 0 + x, color = "deeppink") +
  theme_bw() +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  geom_text(mapping = aes(label = temp), data = text_layer, parse = TRUE) +
  facet_grid(Model ~ cohort)

ggsave("fig2.pdf", fig2, width = 6, height = 3, dpi = 1000)


#Fig 3----
#fig3 = ggplot(mapping = aes(x = rank(younger$ks_lasso), y = rank(younger$ks_lasso_yy)))+
#  geom_point()+
#  #geom_abline(slope = 1, intercept = 0)+
#  geom_segment(aes(x=0, y=0, yend=nrow(younger),xend=nrow(younger))) +
#  theme_bw()+
#  labs(x = "Rank from Model Trained on Older Cohort", y = "Rank from Model Trained on Younger Cohort", title = "")
#fig3
#ggsave("fig3.pdf", fig3, width = 4, height = 4, dpi=1000)


#Fig 4----
# Calibration by ethnicity: one facet per sp_ethn_aw level, with one
# through-the-origin fit per cohort (colour and linetype) in each facet.
younger$panel <- "Performance on Younger Cohort"
older$panel <- "Performance on Older Cohort"
y2 <- younger %>%
  transmute(panel, x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"), sp_ethn_aw)

y_w <- lm_eq(y2[y2$sp_ethn_aw == "White and Other", ])
y_b <- lm_eq(y2[y2$sp_ethn_aw == "Black", ])
y_h <- lm_eq(y2[y2$sp_ethn_aw == "Latino", ])

o2 <- older %>%
  transmute(panel, x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"), sp_ethn_aw)

o_w <- lm_eq(o2[o2$sp_ethn_aw == "White and Other", ])
o_b <- lm_eq(o2[o2$sp_ethn_aw == "Black", ])
o_h <- lm_eq(o2[o2$sp_ethn_aw == "Latino", ])

combo <- rbind(y2, o2)

# One equation label per (cohort, ethnicity) pair; the younger-cohort labels
# come first, matching the order of `panel` below.
text_layer <- data.frame(temp = c(y_w, y_b, y_h, o_w, o_b, o_h))
text_layer$x <- c(.4, .4, .4, .35, .41, .35)
text_layer$y <- c(.82, .82, .82, .95, .95, .95)
text_layer$sp_ethn_aw <- rep(c("White and Other", "Black", "Latino"), 2)
text_layer$panel <- rep(c("Performance on Younger Cohort", "Performance on Older Cohort"), each = 3)

# Same level order as the plotted data so the facets line up.
text_layer$sp_ethn_aw <- factor(text_layer$sp_ethn_aw, levels = levels(combo$sp_ethn_aw))

# annotate() keeps the y = x reference line out of the colour/linetype
# mapping: a geom_segment() layer would inherit `linetype = panel`, be drawn
# once per data row, and add a black line to the legend keys.
# The colour values are named so the cohort-to-colour assignment does not
# depend on the alphabetical order of the `panel` strings.
fig4 <- ggplot(data = combo, mapping = aes(x = x, y = y, color = panel, linetype = panel)) +
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1, color = "black") +
  labs(x = "Predicted Probability of Arrest Between Ages 17 and 24",
       y = "Proportion Arrested Between Ages 17 and 24") +
  geom_smooth(method = "lm", formula = y ~ x + 0, se = FALSE) +
  theme_bw() +
  scale_color_manual(values = c("Performance on Older Cohort" = "red",
                                "Performance on Younger Cohort" = "blue")) +
  theme(legend.position = "bottom", legend.key.width = unit(2.5, "cm"),
        legend.title = element_blank()) +
  facet_grid(~sp_ethn_aw) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  geom_text(aes(label = temp), data = text_layer, parse = TRUE, show.legend = FALSE)

ggsave("fig4.pdf", fig4, width = 8.2, height = 4, dpi = 1000)


#Fig 5----
# Calibration plot for the binCH results: arrest indicator (ages 19-24)
# regressed through the origin on the pnas_lr predicted probability, one
# facet per cohort, with the fitted equation printed in each facet.
y1 <- younger_CH %>%
  transmute(x = pnas_lr, y = as.numeric(arrested19_24 == "arrested"),
            cohort = "Performance of Younger Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(y1)
# One row carrying the facet variables plus the equation label in `temp`.
y1text <- data.frame(y1[1, ], temp)
#y2 <- younger_CH %>% transmute(x = ks_lasso, y= as.numeric(arrested19_24=="arrested"),
#                            cohort="Performance of Younger Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(y2)
#y2text <- data.frame(y2[1,], temp)

o1 <- older_CH %>%
  transmute(x = pnas_lr, y = as.numeric(arrested19_24 == "arrested"),
            cohort = "Performance of Older Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(o1)
o1text <- data.frame(o1[1, ], temp)

#o2 <- older_CH %>% transmute(x = ks_lasso, y= as.numeric(arrested19_24=="arrested"),
#                          cohort="Performance of Older Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(o2)
#o2text <- data.frame(o2[1,], temp)

combo <- rbind(y1, o1 # , y2, o2
               )
text_layer <- rbind(y1text,
                    # y2text,
                    o1text # ,o2text
                    )
# Position of the equation label inside each facet.
text_layer$x <- .4
text_layer$y <- .9

# `label` is mapped only in the text layer (whose data has a `temp` column),
# so rendering the plot does not depend on the global variable `temp`.
# annotate() draws the y = x reference line once per facet instead of once
# per data row.
fig5 <- ggplot(data = combo, mapping = aes(x = x, y = y)) +
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1) +
  labs(x = "Predicted Probability of Arrest Between Ages 19 and 24",
       y = "Prop. Arrested Between Ages 19 and 24") +
  geom_smooth(method = "lm", se = FALSE, formula = y ~ 0 + x, color = "deeppink") +
  theme_bw() +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  geom_text(mapping = aes(label = temp), data = text_layer, parse = TRUE) +
  facet_grid(Model ~ cohort)

ggsave("fig5.pdf", fig5, width = 6, height = 3, dpi = 1000)


#Fig 6----
# Calibration plot restricted to rows with Zage15cblc_lowSC >= 2: arrest
# indicator (ages 17-24) regressed through the origin on the pnas_lr
# predicted probability, one facet per cohort.
# which() keeps only rows where the condition is TRUE; plain logical
# subsetting would insert all-NA rows if Zage15cblc_lowSC contains NA.
y1 <- younger[which(younger$Zage15cblc_lowSC >= 2), ] %>%
  transmute(x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"),
            cohort = "Performance of Younger Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(y1)
# One row carrying the facet variables plus the equation label in `temp`.
y1text <- data.frame(y1[1, ], temp)

#y2 <- younger[younger$Zage15cblc_lowSC>=2,] %>% transmute(x = ks_lasso, y= as.numeric(arrested17_24=="arrested"),
#                            cohort="Performance of Younger Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(y2)
#y2text <- data.frame(y2[1,], temp)

o1 <- older[which(older$Zage15cblc_lowSC >= 2), ] %>%
  transmute(x = pnas_lr, y = as.numeric(arrested17_24 == "arrested"),
            cohort = "Performance of Older Cohort",
            Model = "Classic Risk-Factor Logistic Regression")
temp <- lm_eq(o1)
o1text <- data.frame(o1[1, ], temp)

#o2 <- older[older$Zage15cblc_lowSC>=2,] %>% transmute(x = ks_lasso, y= as.numeric(arrested17_24=="arrested"),
#                          cohort="Performance of Older Cohort", Model = "Full Lasso Logistic Regression")
#temp <- lm_eq(o2)
#o2text <- data.frame(o2[1,], temp)

combo <- rbind(y1, o1 #, y2, o2
               )
text_layer <- rbind(y1text, # y2text,
                    o1text # ,o2text
                    )
# Position of the equation label inside each facet.
text_layer$x <- .4
text_layer$y <- .9

# `label` is mapped only in the text layer (whose data has a `temp` column),
# so rendering the plot does not depend on the global variable `temp`.
# annotate() draws the y = x reference line once per facet instead of once
# per data row.
fig6 <- ggplot(data = combo, mapping = aes(x = x, y = y)) +
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1) +
  labs(x = "Predicted Probability of Arrest Between Ages 17 and 24",
       y = "Prop. Arrested Between Ages 17 and 24") +
  geom_smooth(method = "lm", se = FALSE, formula = y ~ 0 + x, color = "deeppink") +
  theme_bw() +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  geom_text(mapping = aes(label = temp), data = text_layer, parse = TRUE) +
  facet_grid(Model ~ cohort)

ggsave("fig6.pdf", fig6, width = 6, height = 3, dpi = 1000)
